from experiment_utils.largestconnectedcomponent import lcc_dataset
from utils.load_datasets import load_data,data_information
from experiment_utils.sdrf_cuda import sdrf_BFc,sdrf_JTc,sdrf_JLc,sdrf_AFc
from utils.seeds import val_seeds
from utils.splits import set_train_val_test_split,set_train_val_test_split_frac
from experiment_utils.experimentclass import Experiment

from torch_geometric.data import Data
import torch
import torch.nn.functional as F
import torch_geometric
# Use the GPU whenever PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("using device: ", device)
import numpy as np

from tqdm import tqdm 
import os
import json

import wandb


# ---------------------------------------------------------------------------
# Parameters for the experiment
# ---------------------------------------------------------------------------

# Silence wandb console output and numba's CUDA low-occupancy warnings.
# NOTE(review): numba may read NUMBA_* variables when it is first imported;
# if experiment_utils.sdrf_cuda (imported above) pulls in numba, setting the
# variable here may come too late to take effect — confirm.
os.environ.update({
    "WANDB_SILENT": "true",
    "WANDB_CONSOLE": "off",
    "NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS": "false",
})

# Run configuration.
datasetname = "Cornell"
results_dir = "results"
rewiring_run = True             # passed to objective() as its `rewire` flag
make_undirected = True          # forwarded to lcc_dataset(..., to_undirected=...)
int_node = False                # forwarded as `intermediate_node` to the rewiring step
Curvature_type = "BFc_w4cycle"  # selects the SDRF variant in create_rewired_edge_index

path = ""

# Load the dataset, build its LCC version (presumably the largest connected
# component, per `lcc_dataset`), take its single graph, and report on it.
dataset, data, G = load_data(datasetname)
dataset_lcc = lcc_dataset(dataset, to_undirected=make_undirected)
data_lcc = dataset_lcc[0]

data_information(dataset_lcc, data_lcc)


# Curvature variants accepted by ``create_rewired_edge_index``.
SUPPORTED_CURVATURE_TYPES = ("BFc_w4cycle", "BFc_no4cycle", "JTc", "JLc", "AFc_3", "AFc_4")


def _graph_to_undirected_edge_index(G):
    """Convert the edges of graph ``G`` into an undirected ``edge_index`` tensor.

    A graph without edges yields an empty ``(2, 0)`` long tensor.
    ``torch.tensor([])`` would be a 1-D float tensor, which is not a valid
    ``(2, E)`` edge index.
    """
    edges = list(G.edges)
    if not edges:
        return torch.empty((2, 0), dtype=torch.long)
    return torch_geometric.utils.to_undirected(torch.tensor(edges).t())


def create_rewired_edge_index(data, hyperparameters, intermediate_node, remove_edges, curvaturetype: str):
    """Rewire ``data`` with the SDRF variant selected by ``curvaturetype``.

    Args:
        data: graph to rewire; must provide ``is_undirected()``.
        hyperparameters: mapping with keys ``"loops"``, ``"C+"`` (removal
            bound) and ``"tau"``.
        intermediate_node: forwarded as ``int_node``; only the BFc variants
            use it.
        remove_edges: forwarded unchanged to the SDRF routine.
        curvaturetype: one of ``SUPPORTED_CURVATURE_TYPES``.

    Returns:
        ``(G_rewired, edge_index_rewired)``. The first element is the graph
        returned by the SDRF routine. The second is its edge set as an
        undirected edge index tensor.

    Raises:
        ValueError: if ``curvaturetype`` is not one of the supported variants.
    """
    if curvaturetype not in SUPPORTED_CURVATURE_TYPES:
        raise ValueError(
            f"Unknown curvature type {curvaturetype!r}; "
            f"expected one of {SUPPORTED_CURVATURE_TYPES}"
        )

    # Keyword arguments shared by every SDRF variant.
    common_kwargs = dict(
        loops=hyperparameters["loops"],
        remove_edges=remove_edges,
        tau=hyperparameters["tau"],
        is_undirected=data.is_undirected(),
        progress_bar=False,
    )
    removal_bound = hyperparameters["C+"]

    if curvaturetype in ("BFc_w4cycle", "BFc_no4cycle"):
        # fcc=True for the "_w4cycle" variant, False for "_no4cycle".
        G_rewired, _ = sdrf_BFc(
            data,
            removal_bound=removal_bound,
            int_node=intermediate_node,
            fcc=(curvaturetype == "BFc_w4cycle"),
            **common_kwargs,
        )
    elif curvaturetype == "JTc":
        G_rewired, _ = sdrf_JTc(data, removal_bound=removal_bound, **common_kwargs)
    elif curvaturetype == "JLc":
        G_rewired, _ = sdrf_JLc(data, removal_bound=removal_bound, **common_kwargs)
    else:  # "AFc_3" or "AFc_4"
        # The AFc routine receives the negated bound.
        # NOTE(review): presumably AFc uses the opposite sign convention; confirm.
        # k mirrors the variant suffix: float 3. for AFc_3, int 4 for AFc_4.
        G_rewired, _ = sdrf_AFc(
            data,
            removal_bound=-removal_bound,
            k=3. if curvaturetype == "AFc_3" else 4,
            **common_kwargs,
        )

    edge_index_rewired = _graph_to_undirected_edge_index(G_rewired)
    return G_rewired, edge_index_rewired


def objective(config, rewire=False, num_seeds=2, patience=100):
    """Train and evaluate an ``Experiment`` over several seeds and average the results.

    Args:
        config: hyperparameter dict passed to ``Experiment``. When ``rewire``
            is true it is also passed to the rewiring step.
        rewire: if true, rewire ``data_lcc`` once using ``Curvature_type``.
            The rewired edges are then reused for every seed.
        num_seeds: number of leading entries of ``val_seeds`` to run.
        patience: stop training once more than this many consecutive epochs
            pass without a validation improvement.

    Returns:
        ``(mean_val_acc, mean_test_acc)`` over the seeds. Both are measured on
        the model as it stands when training ends, not at the best-validation
        epoch.
    """
    val_acc = []
    test_acc = []
    if rewire:
        print("===Starting Rewiring===")
        # Rewiring happens once; every seed shares the same rewired edge set.
        _, edge_index_rewired = create_rewired_edge_index(
            data_lcc,
            config,
            intermediate_node=int_node,
            remove_edges=True,
            curvaturetype=Curvature_type,
        )
        print(" ")

    print(" == Starting Runs == ")
    for seed in tqdm(val_seeds[:num_seeds]):
        # Cora/Citeseer/Pubmed use set_train_val_test_split. Every other
        # dataset uses the fractional variant with (0.2, 0.2).
        if datasetname in ("Cora", "Citeseer", "Pubmed"):
            data_undirected_split = set_train_val_test_split(seed, data_lcc)
        else:
            data_undirected_split = set_train_val_test_split_frac(seed, data_lcc, 0.2, 0.2)

        if rewire:
            data_undirected_split.edge_index = edge_index_rewired

        data_undirected_split.to(device)

        exp = Experiment(device, datasetname, dataset_lcc, data_undirected_split, config)

        # Early stopping on validation accuracy. `counter` counts epochs since
        # the last strict improvement.
        best_val = None
        counter = 0
        for epoch in range(1, exp.epoch):
            exp.train()
            val = exp.validate()
            if best_val is None:
                best_val = val
            elif val > best_val:
                best_val = val
                counter = 0
            else:
                counter += 1
            if counter > patience:
                break

        val_acc.append(exp.validate())
        test_acc.append(exp.test())
    print("")
    return np.mean(np.array(val_acc)), np.mean(np.array(test_acc))


# Hyperparameters for the model and the rewiring step.
# "loops", "C+" and "tau" are read by create_rewired_edge_index.
config_hyper = {
    "learning_rate": 0.07468,
    "layers": [128],
    "dropout": 0.214,
    "weight_decay": 0.7837,
    "loops": 200,
    "C+": 16.881,
    "tau": 212,
}

import time

# Wall-clock time for the whole run (optional rewiring plus all seeds).
start_time = time.time()
accuracies = objective(config_hyper, rewire=rewiring_run)
print("Average accuracy", accuracies)

print(f"--- {time.time() - start_time} seconds ---")